import torch
import torch.nn as nn

# Two tensors of shape (1, 1, 32) filled with samples from torch.randn.
q, k = torch.randn(1, 1, 32), torch.randn(1, 1, 32)

# Join q and k along dimension 2, so that dimension grows to 32 + 32 = 64.
ret = torch.cat((q, k), dim=2)
print(ret.shape)

# Linear layer mapping the 64 concatenated features down to 32.
Linear1 = nn.Linear(in_features=64, out_features=32)
print(Linear1(ret))
